import torch
import logging
import os
import sys
import numpy as np
from src.eval import (
    load_original_model_and_inputs,
    time_execution_with_cuda_event,
    get_timing_stats,
    set_seed,
    fetch_ref_arch_from_problem_id,
)
from src.dataset import construct_problem_dataset_from_problem_dir
import os, sys
import logging
import json

# Default device onto which models and inputs are moved by this script.
device = torch.device("cuda:0")

# Repository root: the parent of the directory that holds this file.
REPO_TOP_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Root directory of the KernelBench problem folders (level1, level2, ...).
KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")


# This script requires a CUDA device with compute capability >= 7.0 for
# torch.compile. Raise explicitly instead of using `assert`, which is stripped
# when Python runs with -O. Check availability first so machines without CUDA
# get this message rather than an unrelated error from get_device_capability().
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    raise RuntimeError("torch.compile is not supported on this device.")

def get_toy_torch_compile_fn_and_inputs(device="cuda"):
    """
    Build a tiny torch.compile'd function together with matching inputs.

    Handy as a minimal argument pair for ``inspect_torch_compile``.

    Args:
        device: Device the example tensors are allocated on. Defaults to
            ``"cuda"``; any device string accepted by ``torch.ones`` works
            (e.g. ``"cpu"``). This parameter shadows the module-level
            ``device`` only inside this function.

    Returns:
        ``(fn, inputs)`` where ``fn(x, y)`` returns ``x + y + 2`` and is
        wrapped with ``torch.compile()``, and ``inputs`` is a tuple of two
        2x2 tensors: all ones and all zeros.
    """
    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    inputs = (
        torch.ones(2, 2, device=device),
        torch.zeros(2, 2, device=device),
    )
    return fn, inputs


def inspect_torch_compile(fn, inputs, output_dir="results/triton_code", filename="optimized_kernel"):
    """
    Benchmark a torch.compile'd function by viewing dynamo tracing, traced graph,
    fusion decisions and generated code.
    
    Args:
        fn: The compiled function to benchmark
        inputs: Tuple of input tensors to the function
        output_dir: Directory to save generated code
    """
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()

    separator("Dynamo Tracing")
    # View dynamo tracing
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph") 
    # View traced graph
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    os.makedirs(output_dir, exist_ok=True)
    
    # Create a custom logging handler to capture the output
    class OutputCodeHandler(logging.Handler):
        def __init__(self, file):
            super().__init__()
            self.file = file
        
        def emit(self, record):
            self.file.write(self.format(record) + '\n')

    with open(f"{output_dir}/{filename}.py", "w") as f:
        # Set up logging handler
        handler = OutputCodeHandler(f)
        logging.getLogger("torch._inductor.codecache").addHandler(handler)
        
        torch._logging.set_logs(output_code=True)
        fn(*inputs)  # Run the function
        
        # Clean up handler
        logging.getLogger("torch._inductor.codecache").removeHandler(handler)

    separator("")
    
def fetch_ref_arch_from_level_problem_id(level_num, problem_id, with_name=False):
    """
    Look up the reference architecture for one problem of a KernelBench level.

    The problem dataset is built from ``<KERNEL_BENCH_PATH>/level<level_num>``.
    The lookup is delegated to ``fetch_ref_arch_from_problem_id``, whose
    result is returned unchanged. Callers in this file pass
    ``with_name=True`` and unpack a ``(name, source)`` pair.
    """
    level_dir = os.path.join(KERNEL_BENCH_PATH, f"level{level_num}")
    return fetch_ref_arch_from_problem_id(
        problem_id,
        construct_problem_dataset_from_problem_dir(level_dir),
        with_name,
    )

def inspect_torch_compile_triton(level_num, problem_id):
    """
    Load the reference model definition for a KernelBench problem.

    TODO(review): this looks like an unfinished variant of
    ``inspect_baseline_torch_compile``. It loads the problem but never
    compiles or inspects it. It returns what it loaded so callers can build
    on it instead of getting nothing back.

    Args:
        level_num: KernelBench level number.
        problem_id: Problem id within that level.

    Returns:
        Tuple ``(ref_arch_name, Model, get_init_inputs, get_inputs)``.
        ``ref_arch_name`` is the last "/"-separated component of the name
        returned by the lookup. The other three are exactly what
        ``load_original_model_and_inputs`` returned.
    """
    ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
        level_num, problem_id, with_name=True
    )
    ref_arch_name = ref_arch_name.split("/")[-1]
    context = {}
    Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
        ref_arch_src, context
    )
    return ref_arch_name, Model, get_init_inputs, get_inputs


def inspect_baseline_torch_compile(level_num, problem_id, output_dir="results/triton_code"):
    """
    Get the Triton code generated by torch.compile for a KernelBench problem.

    Steps:
      1. Load the reference ``Model`` for ``problem_id`` of level ``level_num``.
      2. Build it from seeded init inputs and move it to the module-level
         ``device``.
      3. Wrap it in ``torch.compile``.
      4. Run ``inspect_torch_compile`` on seeded inputs under
         ``torch.no_grad()``.

    The generated code is written to
    ``<output_dir>/level<level_num>_problem<problem_id>_triton.py``.

    Args:
        level_num: KernelBench level number.
        problem_id: Problem id within that level.
        output_dir: Directory for the generated code file.

    Errors while loading the problem propagate. Errors while building,
    compiling or inspecting the model are reported instead of raised: a
    message is printed to stdout and the traceback to stderr.
    """
    # Only the source is needed here; the name is discarded.
    _, ref_arch_src = fetch_ref_arch_from_level_problem_id(
        level_num, problem_id, with_name=True
    )
    context = {}
    Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
        ref_arch_src, context
    )

    def to_device(values):
        # Move tensors to the target device; leave non-tensor arguments as-is.
        return [
            v.cuda(device=device) if isinstance(v, torch.Tensor) else v
            for v in values
        ]

    try:
        with torch.no_grad():
            torch.cuda.synchronize(device=device)
            # Reset the seed before each factory so both see the same RNG state.
            set_seed(42)
            inputs = get_inputs()
            set_seed(42)
            init_inputs = get_init_inputs()
            inputs = to_device(inputs)
            init_inputs = to_device(init_inputs)

            # Move the eager model to the device first, then compile it.
            model = Model(*init_inputs).cuda(device=device)
            compiled_model = torch.compile(model)
            inspect_torch_compile(
                compiled_model,
                inputs,
                output_dir=output_dir,
                filename=f"level{level_num}_problem{problem_id}_triton",
            )
    except Exception as e:
        # Top-level reporting boundary: do not crash the caller, but keep the
        # traceback so failures are debuggable.
        import traceback

        print(f"[Eval] Error in Inspecting Torch Compile: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    # Dump the torch.compile-generated code for one KernelBench problem.
    # For a minimal example, pass the pair returned by
    # get_toy_torch_compile_fn_and_inputs() to inspect_torch_compile instead.
    level_num, problem_id = 2, 43
    inspect_baseline_torch_compile(level_num, problem_id)


